import abc
import random
from itertools import chain
from typing import Sequence, Any, Tuple

from gym import spaces as spaces

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv


def add_posn(ab, cd):
    """Return the component-wise sum of two 2-element positions as a tuple."""
    (x0, y0), (x1, y1) = ab, cd
    return x0 + x1, y0 + y1


def flatten(iterator_to_flatten):
    """Concatenate the items of every inner iterable into one flat list."""
    return [item for group in iterator_to_flatten for item in group]


class FastGridWorldDirectional(MultiAgentSafetyEnv, abc.ABC):
    """
    A grid world with many pre-computed properties for extremely fast steps
    (well, as fast as you can get with python)

    Each agent has a position and a direction. The environment state is a flat
    sequence ``[loc_0, dir_0, loc_1, dir_1, ...]`` where ``loc_i`` is an index
    into ``grid_posns`` and ``dir_i`` is one of the directions listed below.

    NOTE(review): an earlier version of this docstring referred to a per-agent
    "cone of visibility" with the cell coding below, but no diagram was given
    and ``project_obs`` in this class simply hands every agent the full state.
    Presumably the coding is meant for subclasses that override
    ``project_obs`` -- confirm.
    0 = Empty
    1 = Filled with another agent
    2 = Wall

    This environment generates a single AP, representing if any two agents have collided with each other.

    Agents have four actions.
    0 = Do nothing
    1 = Move forward
    2 = Turn left (counterclockwise)
    3 = Turn right (clockwise)

    Agent directions:
    0 = Facing up
    1 = Right
    2 = Down
    3 = Left
    """

    def __init__(self, grid_posns, num_agents, ending_idx):
        """
        :param grid_posns: sequence of hashable 2-element positions, one per free cell
        :param num_agents: number of agents in the environment
        :param ending_idx: per-agent goal cell, given as indices into ``grid_posns``;
            must contain exactly ``num_agents`` entries
        """
        # Create and cache a bunch of useful arrays so that we don't need to recalculate them every time step
        self.grid_posns = grid_posns

        # posn -> idx
        self.grid_posn_inv = {pos: idx for idx, pos in enumerate(self.grid_posns)}

        # dir, idx -> idx: What position after moving forward one step.
        # Offsets are ordered to match the direction encoding (up, right, down, left).
        # A move that would leave the grid maps back to the current cell.
        self.loc_dir_next = [
            [self.grid_posn_inv.get(add_posn(pos, offset), idx) for idx, pos in enumerate(self.grid_posns)]
            for offset in [(0, -1), (1, 0), (0, 1), (-1, 0)]]

        self.num_agents = num_agents

        assert len(ending_idx) == num_agents

        self.goal_idx = ending_idx

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """Return one observation space per agent (each is the full state space)."""
        return [self.state_space()] * self.num_agents

    def agent_actions_spaces(self) -> Sequence[spaces.Space]:
        """Return one 4-way discrete action space per agent."""
        return [spaces.Discrete(4)] * self.num_agents

    def state_space(self) -> spaces.Space:
        """Return the space of flat states: a (cell index, direction) pair per agent."""
        return spaces.MultiDiscrete([len(self.grid_posns), 4] * self.num_agents)

    def initial_state(self):
        """
        Sample a random initial state.

        Agents are placed on distinct cells (sampling without replacement) with
        independently chosen directions.

        :return: ``(state, observations)`` where ``state`` is the flat
            ``[loc_0, dir_0, ...]`` list and ``observations`` is ``project_obs(state)``
        """
        starting_locs = random.sample(range(len(self.grid_posns)), self.num_agents)
        starting_dirs = random.choices(range(4), k=self.num_agents)

        state = flatten(zip(starting_locs, starting_dirs))
        return state, self.project_obs(state)

    def get_next_loc_dir(self, loc, dir, act):
        """
        Apply a single agent's action to its (cell index, direction) pair.

        :param loc: current cell index
        :param dir: current direction (0-3)
        :param act: action (0 = stay, 1 = forward, 2 = turn left, anything else = turn right)
        :return: the resulting ``(loc, dir)`` tuple
        """
        if act == 0:
            return loc, dir
        elif act == 1:
            return self.loc_dir_next[dir][loc], dir
        elif act == 2:
            return loc, (dir - 1) % 4
        else:
            return loc, (dir + 1) % 4

    def step(self, environment_state, joint_action: Sequence[Any]) -> Tuple[
        Any, Sequence[Any], Sequence[float], bool, bool]:
        """
        Advance every agent by one action.

        :param environment_state: flat state ``[loc_0, dir_0, loc_1, dir_1, ...]``
        :param joint_action: one action per agent, in agent order
        :return: ``(next_state, observations, rewards, done, safe)`` where
            ``next_state`` uses the same flat layout as ``environment_state``,
            ``rewards`` holds 1 for each agent standing on its goal cell (else 0),
            ``done`` is True once every agent is on its goal, and ``safe`` is
            False when any two agents end on the same cell or swap cells.
        """
        assert len(environment_state) == self.num_agents * 2
        # range() takes no keyword arguments, so the stride must be positional.
        locs_and_dirs = [
            (environment_state[i], environment_state[i + 1])
            for i in range(0, len(environment_state), 2)
        ]

        new_locs_and_dirs = [
            self.get_next_loc_dir(loc, direction, act)
            for (loc, direction), act in zip(locs_and_dirs, joint_action)
        ]

        def check_collisions_and_crossings(idx1, idx2):
            loc1 = locs_and_dirs[idx1][0]
            loc2 = locs_and_dirs[idx2][0]
            nloc1 = new_locs_and_dirs[idx1][0]
            nloc2 = new_locs_and_dirs[idx2][0]

            # Same destination cell, or the two agents traded places.
            return (nloc1 == nloc2) or (nloc1 == loc2 and nloc2 == loc1)

        # Each unordered pair of agents is examined once (idx2 < idx1).
        collisions_or_crossings = any(
            check_collisions_and_crossings(idx1, idx2)
            for idx1 in range(self.num_agents)
            for idx2 in range(idx1)
        )

        rewards = [1 if loc == goal else 0 for (loc, _), goal in zip(new_locs_and_dirs, self.goal_idx)]
        done = sum(rewards) == self.num_agents

        # Re-flatten so the returned state matches initial_state()/state_space()
        # and can be passed straight back into step().
        new_state = flatten(new_locs_and_dirs)

        return new_state, self.project_obs(new_state), rewards, done, (not collisions_or_crossings)

    def project_obs(self, state) -> Sequence[Any]:
        """Return the per-agent observations; here every agent sees the full state."""
        return [state] * self.num_agents
